{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting experience...\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "To learn more, visit my Python tutorial page: https://morvanzhou.github.io/tutorials/\n",
    "My YouTube channel: https://www.youtube.com/user/MorvanZhou\n",
    "More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/\n",
    "Dependencies:\n",
    "tensorflow: 1.1.0\n",
    "matplotlib\n",
    "numpy\n",
    "gym: 0.8.1\n",
    "\"\"\"\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import gym\n",
    "\n",
    "# Fix TensorFlow's graph-level seed and NumPy's global seed so that weight\n",
    "# initialisation, exploration and minibatch sampling are repeatable.\n",
    "tf.set_random_seed(1)\n",
    "np.random.seed(1)\n",
    "\n",
    "# Hyper Parameters\n",
    "# Number of transitions sampled from MEMORY on every learn() call.\n",
    "BATCH_SIZE = 32\n",
    "LR = 0.01                   # learning rate passed to the Adam optimizer\n",
    "EPSILON = 0.9               # probability of taking the greedy action in choose_action()\n",
    "GAMMA = 0.9                 # reward discount\n",
    "TARGET_REPLACE_ITER = 100   # copy 'q' weights into 'q_next' every this many learn() calls\n",
    "# Number of rows (transitions) the replay memory can hold.\n",
    "MEMORY_CAPACITY = 2000\n",
    "MEMORY_COUNTER = 0          # total transitions stored so far (incremented by store_transition)\n",
    "LEARNING_STEP_COUNTER = 0   # number of learn() calls so far; drives target updating\n",
    "env = gym.make('CartPole-v0')\n",
    "# Work with the underlying environment object instead of gym's wrapper; the\n",
    "# training cell reads env.x_threshold / env.theta_threshold_radians from it.\n",
    "# NOTE(review): presumably this also removes the wrapper's episode step limit -\n",
    "# confirm against the installed gym version.\n",
    "env = env.unwrapped\n",
    "N_ACTIONS = env.action_space.n\n",
    "N_STATES = env.observation_space.shape[0]\n",
    "# Each row stores one transition as [s, a, r, s_], hence N_STATES * 2 + 2 columns.\n",
    "MEMORY = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memory\n",
    "\n",
    "# tf placeholders: batch of states, taken actions, rewards and next states\n",
    "tf_s = tf.placeholder(tf.float32, [None, N_STATES])\n",
    "tf_a = tf.placeholder(tf.int32, [None, ])\n",
    "tf_r = tf.placeholder(tf.float32, [None, ])\n",
    "tf_s_ = tf.placeholder(tf.float32, [None, N_STATES])\n",
    "\n",
    "# One hidden layer of 10 ReLU units followed by a linear layer with one output\n",
    "# per action; kernels start from a normal distribution (mean 0, stddev 0.1).\n",
    "with tf.variable_scope('q'):        # evaluation network\n",
    "    l_eval = tf.layers.dense(tf_s, 10, tf.nn.relu, kernel_initializer=tf.random_normal_initializer(0, 0.1))\n",
    "    q = tf.layers.dense(l_eval, N_ACTIONS, kernel_initializer=tf.random_normal_initializer(0, 0.1))\n",
    "\n",
    "# Same layer sizes as the evaluation network, fed with the next state. Its\n",
    "# variables are created with trainable=False and are overwritten from the 'q'\n",
    "# scope inside learn().\n",
    "with tf.variable_scope('q_next'):   # target network, not to train\n",
    "    l_target = tf.layers.dense(tf_s_, 10, tf.nn.relu, trainable=False)\n",
    "    q_next = tf.layers.dense(l_target, N_ACTIONS, trainable=False)\n",
    "\n",
    "# TD target: r + GAMMA * max_a' q_next(s_, a').\n",
    "# NOTE(review): no terminal-state mask is applied; the stored transitions do\n",
    "# not carry a done flag.\n",
    "q_target = tf_r + GAMMA * tf.reduce_max(q_next, axis=1)                   # shape=(None, )\n",
    "\n",
    "# Pair every row number with the action taken in that row so gather_nd can pick\n",
    "# q[row, action] for each sample of the batch.\n",
    "a_indices = tf.stack([tf.range(tf.shape(tf_a)[0], dtype=tf.int32), tf_a], axis=1)\n",
    "q_wrt_a = tf.gather_nd(params=q, indices=a_indices)     # shape=(None, ), q for current state\n",
    "\n",
    "# Mean squared difference between the TD target and the q value of the taken\n",
    "# action. minimize() updates trainable variables only, so the 'q_next' layers\n",
    "# (trainable=False) are not changed by train_op.\n",
    "loss = tf.reduce_mean(tf.squared_difference(q_target, q_wrt_a))\n",
    "train_op = tf.train.AdamOptimizer(LR).minimize(loss)\n",
    "\n",
    "# Module-level session shared by choose_action() and learn().\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "\n",
    "def choose_action(s):\n",
    "    \"\"\"Pick an action for observation `s` with an epsilon-greedy rule.\n",
    "\n",
    "    With probability EPSILON the action with the highest q value predicted by\n",
    "    the evaluation network is returned; otherwise an action index is drawn\n",
    "    uniformly from [0, N_ACTIONS).\n",
    "    \"\"\"\n",
    "    # The network expects a batch, so turn the single observation into one row.\n",
    "    batch = s[np.newaxis, :]\n",
    "\n",
    "    if not np.random.uniform() < EPSILON:\n",
    "        # explore: ignore the network and act randomly\n",
    "        return np.random.randint(0, N_ACTIONS)\n",
    "\n",
    "    # exploit: forward feed the observation and take the best-valued action\n",
    "    q_values = sess.run(q, feed_dict={tf_s: batch})\n",
    "    return np.argmax(q_values)\n",
    "\n",
    "\n",
    "def store_transition(s, a, r, s_):\n",
    "    \"\"\"Write the transition (s, a, r, s_) into the replay memory.\n",
    "\n",
    "    The transition is flattened to one row [s, a, r, s_]. Rows are used in a\n",
    "    ring: once MEMORY_CAPACITY transitions have been stored, the oldest rows\n",
    "    are overwritten. MEMORY_COUNTER is incremented on every call.\n",
    "    \"\"\"\n",
    "    global MEMORY_COUNTER\n",
    "    # wrap around so new experience replaces the oldest one\n",
    "    row = MEMORY_COUNTER % MEMORY_CAPACITY\n",
    "    MEMORY[row, :] = np.hstack((s, [a, r], s_))\n",
    "    MEMORY_COUNTER += 1\n",
    "\n",
    "\n",
    "def learn():\n",
    "    \"\"\"Run one training step of the evaluation network on a random minibatch.\n",
    "\n",
    "    Every TARGET_REPLACE_ITER-th call (starting with the first one) the weights\n",
    "    of the evaluation network ('q' scope) are copied into the target network\n",
    "    ('q_next' scope) before the training step. A minibatch of BATCH_SIZE\n",
    "    transitions is then sampled, with replacement, from the rows of MEMORY that\n",
    "    have actually been written, and train_op is run on it.\n",
    "\n",
    "    Returns None. Does nothing if no transition has been stored yet.\n",
    "    \"\"\"\n",
    "    global LEARNING_STEP_COUNTER\n",
    "\n",
    "    # Only rows that store_transition() has filled hold real experience.\n",
    "    n_stored = min(MEMORY_COUNTER, MEMORY_CAPACITY)\n",
    "    if n_stored == 0:\n",
    "        return\n",
    "\n",
    "    # update target net\n",
    "    if LEARNING_STEP_COUNTER % TARGET_REPLACE_ITER == 0:\n",
    "        cached = getattr(learn, '_target_sync', None)\n",
    "        if cached is None or cached[0] is not sess.graph:\n",
    "            # Build the assign ops once and reuse them: calling tf.assign on\n",
    "            # every update would keep adding new ops to the graph.\n",
    "            # tf.get_collection filters names with re.match(scope, name), so a\n",
    "            # bare 'q' would also match 'q_next/...' and the optimizer's slot\n",
    "            # variables created under 'q/'. Anchor both scopes with '/' and take\n",
    "            # only the trainable variables of the evaluation network.\n",
    "            t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_next/')\n",
    "            e_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='q/')\n",
    "            if len(t_params) != len(e_params):\n",
    "                raise ValueError('target and evaluation networks have different variables')\n",
    "            sync_ops = [tf.assign(t, e) for t, e in zip(t_params, e_params)]\n",
    "            cached = learn._target_sync = (sess.graph, sync_ops)\n",
    "        sess.run(cached[1])\n",
    "    LEARNING_STEP_COUNTER += 1\n",
    "\n",
    "    # learning\n",
    "    sample_index = np.random.choice(n_stored, BATCH_SIZE)\n",
    "    b_memory = MEMORY[sample_index, :]\n",
    "    # each row is laid out as [s, a, r, s_]\n",
    "    b_s = b_memory[:, :N_STATES]\n",
    "    b_a = b_memory[:, N_STATES].astype(int)\n",
    "    b_r = b_memory[:, N_STATES+1]\n",
    "    b_s_ = b_memory[:, -N_STATES:]\n",
    "    sess.run(train_op, {tf_s: b_s, tf_a: b_a, tf_r: b_r, tf_s_: b_s_})\n",
    "\n",
    "# The training cell only starts calling learn() once MEMORY_COUNTER exceeds\n",
    "# MEMORY_CAPACITY, so the run begins with a pure experience-collection phase.\n",
    "print('\\nCollecting experience...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "NotImplementedError",
     "evalue": "abstract",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNotImplementedError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-2-65b11eddddfc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m     \u001b[0mep_r\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m     \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m         \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      6\u001b[0m         \u001b[0ma\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mchoose_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ms\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\gym\\envs\\classic_control\\cartpole.py\u001b[0m in \u001b[0;36mrender\u001b[1;34m(self, mode)\u001b[0m\n\u001b[0;32m    149\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    150\u001b[0m             \u001b[1;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 151\u001b[1;33m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mscreen_width\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscreen_height\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    152\u001b[0m             \u001b[0ml\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m-\u001b[0m\u001b[0mcartwidth\u001b[0m\u001b[1;33m/\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcartwidth\u001b[0m\u001b[1;33m/\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcartheight\u001b[0m\u001b[1;33m/\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m-\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[1;33m/\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    153\u001b[0m             \u001b[0maxleoffset\u001b[0m \u001b[1;33m=\u001b[0m\u001b[0mcartheight\u001b[0m\u001b[1;33m/\u001b[0m\u001b[1;36m4.0\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\gym\\envs\\classic_control\\rendering.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, width, height, display)\u001b[0m\n\u001b[0;32m     58\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwidth\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mwidth\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     59\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mheight\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 60\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwindow\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mWindow\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwidth\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mwidth\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mheight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mheight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     61\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mon_close\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwindow_closed_by_user\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     62\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misopen\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\pyglet\\window\\__init__.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, width, height, caption, resizable, style, fullscreen, visible, vsync, display, screen, config, context, mode)\u001b[0m\n\u001b[0;32m    502\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    503\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mscreen\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 504\u001b[1;33m             \u001b[0mscreen\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_default_screen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    505\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    506\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\pyglet\\canvas\\base.py\u001b[0m in \u001b[0;36mget_default_screen\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     71\u001b[0m         \u001b[1;33m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;32mclass\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     72\u001b[0m         '''\n\u001b[1;32m---> 73\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_screens\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     74\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     75\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mget_windows\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mC:\\ProgramData\\Anaconda3\\lib\\site-packages\\pyglet\\canvas\\base.py\u001b[0m in \u001b[0;36mget_screens\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     63\u001b[0m         \u001b[1;33m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mlist\u001b[0m \u001b[0mof\u001b[0m \u001b[1;33m:\u001b[0m\u001b[1;32mclass\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     64\u001b[0m         '''\n\u001b[1;32m---> 65\u001b[1;33m         \u001b[1;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'abstract'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     66\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     67\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mget_default_screen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNotImplementedError\u001b[0m: abstract"
     ]
    }
   ],
   "source": [
    "# Training loop: play episodes, store every transition and, once the replay\n",
    "# memory has been filled, run one learn() step after every environment step.\n",
    "N_EPISODES = 400\n",
    "RENDER = True    # set to False to train without opening a viewer window\n",
    "\n",
    "for i_episode in range(N_EPISODES):\n",
    "    s = env.reset()\n",
    "    ep_r = 0\n",
    "    while True:\n",
    "        if RENDER:\n",
    "            try:\n",
    "                env.render()\n",
    "            except Exception as render_error:\n",
    "                # Rendering is only a visual aid. When no display is available\n",
    "                # the viewer cannot be created (e.g. pyglet raises\n",
    "                # NotImplementedError('abstract')); report it once and keep\n",
    "                # training instead of aborting the whole run.\n",
    "                print('Rendering disabled:', repr(render_error))\n",
    "                RENDER = False\n",
    "        a = choose_action(s)\n",
    "\n",
    "        # take action\n",
    "        s_, r, done, info = env.step(a)\n",
    "\n",
    "        # modify the reward: r1 grows as the cart stays near the centre of the\n",
    "        # track, r2 grows as the pole stays near upright\n",
    "        x, _, theta, _ = s_\n",
    "        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8\n",
    "        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5\n",
    "        r = r1 + r2\n",
    "\n",
    "        store_transition(s, a, r, s_)\n",
    "\n",
    "        ep_r += r\n",
    "        # start learning only after the replay memory has been filled once\n",
    "        if MEMORY_COUNTER > MEMORY_CAPACITY:\n",
    "            learn()\n",
    "            if done:\n",
    "                print('Ep: ', i_episode,\n",
    "                      '| Ep_r: ', round(ep_r, 2))\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "        s = s_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
